import torch
import random
import torch.nn as nn
from image_synthesis.modeling.embeddings.base_embedding import BaseEmbedding
from image_synthesis.modeling.utils.misc import mask_with_top_k

class SimpleEmbedding(BaseEmbedding):
    """Token embedding table with an optional additive positional embedding.

    Args:
        num_embed: number of rows in the token embedding table.
        embed_dim: size of every embedding vector.
        num_pos_embed: number of positions supported by the positional
            embedding; ``0`` disables the positional embedding.
        trainable: stored on the instance and consulted by ``get_loss``;
            ``_set_trainable`` (defined on the base class, not visible
            here) is called at the end of construction -- presumably it
            freezes the parameters when this is False, TODO confirm.
        pos_emb_type: ``'embedding'`` (an ``nn.Embedding`` lookup) or
            ``'parameter'`` (a zero-initialized ``1 x num_pos_embed x
            embed_dim`` ``nn.Parameter``).
        mask_embedding: if True, ``forward`` multiplies the token
            embeddings by the ``mask`` argument.
        emb_orthogonal_loss: if True, ``get_loss`` returns a penalty used
            to make the embeddings orthogonal with each other.

    Raises:
        ValueError: if ``num_pos_embed > 0`` and ``pos_emb_type`` is not one
            of the two supported values.
    """

    def __init__(self, 
                 num_embed, 
                 embed_dim, 
                 num_pos_embed=0,
                 trainable=True,
                 pos_emb_type='embedding',
                 mask_embedding=False,
                 emb_orthogonal_loss=False, # used to make the embedding orthogonal with each other
        ):
        super().__init__()
        
        self.num_embed = num_embed
        self.embed_dim = embed_dim
        self.num_pos_embed = num_pos_embed
        self.trainable = trainable
        self.pos_emb_type = pos_emb_type
        self.mask_embedding = mask_embedding
        self.emb_orthogonal_loss = emb_orthogonal_loss

        self.emb = nn.Embedding(num_embed, embed_dim)
        if self.num_pos_embed > 0:
            if self.pos_emb_type == 'embedding':
                self.pos_emb = nn.Embedding(self.num_pos_embed, embed_dim)
            elif self.pos_emb_type == 'parameter':
                self.pos_emb = nn.Parameter(torch.zeros(1, self.num_pos_embed, embed_dim))
            else:
                # Fail here rather than leaving ``self.pos_emb`` undefined,
                # which would only surface as an AttributeError in forward().
                raise ValueError(
                    "Unsupported pos_emb_type {!r}; expected 'embedding' or 'parameter'".format(
                        self.pos_emb_type))
        else:
            self.pos_emb = None

        self._set_trainable()


    def get_loss(self):
        """Return the orthogonality penalty on the token embedding table.

        The penalty is the mean absolute cosine similarity over all ordered
        pairs of distinct embedding rows.

        Returns:
            A scalar tensor when both ``trainable`` and
            ``emb_orthogonal_loss`` are set, otherwise ``None``.
        """
        if not (self.trainable and self.emb_orthogonal_loss):
            return None

        weight = self.emb.weight # N x D
        cov = torch.einsum('md,nd->mn', weight, weight) # N x N
        norm = torch.norm(weight, p=2, dim=1, keepdim=True) # N x 1
        norm = norm.permute(1, 0) * norm # N x N, products of row norms
        cov = cov / norm # pairwise cosine similarity

        # Zero out the diagonal: every row trivially has similarity 1 with itself.
        num_rows = cov.shape[0]
        off_diagonal = 1 - torch.eye(num_rows, dtype=cov.dtype, device=cov.device)
        cov = cov * off_diagonal

        # With a single row there are no off-diagonal pairs; guard the
        # division so the (zero) sum yields 0 instead of NaN.
        num_pairs = max(num_rows * num_rows - num_rows, 1)
        return cov.abs().sum() / num_pairs


    def forward(self, index, mask=None, **kwargs):
        """
        Embed a batch of token indices.

        index: B x L 
        mask: B x L, True for unmasked token, False for masked token;
              required when ``mask_embedding`` is enabled
        return: B x L x D

        Raises RuntimeError if the embedding lookup fails (e.g. an index is
        out of range) and ValueError if ``mask_embedding`` is enabled but no
        mask is given.
        """
        assert index.dim() == 2 # B x L

        # some padded token maybe negative, so set them to 0; clamp creates a
        # new tensor so the caller's index tensor is left untouched
        index = index.clamp(min=0)
        try:
            emb = self.emb(index) # B x L x D
        except (IndexError, RuntimeError) as err:
            raise RuntimeError('IndexError: index out of range, max index {}, num embed {}'.format(index.max(), self.num_embed)) from err
        
        if self.mask_embedding:
            if mask is None:
                raise ValueError('mask must be provided when mask_embedding is True')
            # ``.to(emb)`` matches the mask's dtype/device to the embeddings
            emb = emb * mask.unsqueeze(-1).to(emb)

        if self.pos_emb is not None:
            if self.pos_emb_type == 'embedding':
                positions = torch.arange(index.shape[1], device=index.device).view(1, index.shape[1])
                pos_emb = self.pos_emb(positions) # 1 x L x D
            else:
                pos_emb = self.pos_emb[:, :emb.shape[1], :]
            # out-of-place add: avoids modifying an autograd intermediate in place
            emb = emb + pos_emb

        return emb
